"""
Phase 5: Eurozone Stress Test — Demographics Under a Fixed Regime
=================================================================
1. Eurozone-only CA regressions
2. Eurozone vs OECD floaters comparison
3. Within-EMU demographic dispersion → CA dispersion
4. Counterfactual: what regime would EMU members choose if unconstrained?
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# PROJECT_DIR is the grandparent of this file; ROOT_DIR is one level above it
# and is expected to contain a "multilateral/src" directory providing the
# `model` module (PanelGLS estimator).
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Prepend (index 0) so `model` resolves to multilateral/src before any other
# module of the same name; this must run before the import on the next line.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel location and output directory for markdown tables.
# NOTE: the output directory is created at import time (side effect).
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# OECD members (as of 2024, 38 countries), as ISO3 codes.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
}

# Eurozone members (as of 2024, 20 countries) mapped to the year they
# adopted the euro; used to restrict samples to post-accession years.
EUROZONE_JOIN = {
    'AUT': 1999, 'BEL': 1999, 'FIN': 1999, 'FRA': 1999, 'DEU': 1999,
    'IRL': 1999, 'ITA': 1999, 'LUX': 1999, 'NLD': 1999, 'PRT': 1999,
    'ESP': 1999, 'GRC': 2001, 'SVN': 2007, 'CYP': 2008, 'MLT': 2008,
    'SVK': 2009, 'EST': 2011, 'LVA': 2014, 'LTU': 2015,
    # Croatia adopted the euro on 1 January 2023 (20th member).
    'HRV': 2023,
}
EUROZONE_ISO3 = set(EUROZONE_JOIN.keys())


def stars(p):
    """Return the significance marker for a p-value.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one coefficient and its standard error for a results table.

    Returns a pair ``(coef_cell, se_cell)``: the coefficient to four decimals
    followed by its significance stars, and the standard error to four
    decimals wrapped in parentheses.
    """
    coef_cell = "{:.4f}{}".format(val, stars(p))
    se_cell = "({:.4f})".format(se)
    return coef_cell, se_cell


def run_panel_gls(df, y_var, x_vars, label):
    """Fit PanelGLS of ``y_var`` on ``x_vars`` and collect the estimates.

    Rows with a missing value in the outcome, any regressor, ``iso3`` or
    ``year`` are dropped before fitting. Returns None (after printing a
    message) when fewer than 50 complete rows remain or when the fit raises.

    On success returns a dict with 'model' (= label), 'n_obs',
    'n_countries', 'r_squared', 'rho' and, for each regressor ``name``,
    '{name}_coef', '{name}_se' and '{name}_p'.
    """
    sample = df[[y_var] + x_vars + ['iso3', 'year']].dropna()
    n_complete = len(sample)
    if n_complete < 50:
        print(f"  {label}: insufficient obs ({n_complete}), skipping")
        return None

    estimator = PanelGLS()
    outcome = sample[y_var].values
    design = sample[x_vars].values
    try:
        estimator.fit(outcome, design, sample['iso3'].values, sample['year'].values)
    except Exception as e:
        print(f"  {label}: GLS failed ({e}), skipping")
        return None

    summary = {
        'model': label,
        'n_obs': estimator.n_obs,
        'n_countries': estimator.n_countries,
        'r_squared': estimator.r_squared,
        'rho': estimator.rho,
    }
    print(f"\n  {label} (N={estimator.n_obs}, R²={estimator.r_squared:.4f})")
    for idx, name in enumerate(x_vars):
        marker = stars(estimator.pvalues[idx])
        coef = estimator.beta[idx]
        se = estimator.se[idx]
        print(f"    {name:30s} {coef:8.4f} ({se:.4f}) {marker}")
        summary.update({
            f'{name}_coef': coef,
            f'{name}_se': se,
            f'{name}_p': estimator.pvalues[idx],
        })

    return summary


def run_logit(df, y_var, x_vars, label):
    """Pooled logit with Hessian SEs (manual implementation via BFGS).

    Rows with a missing value in ``y_var``, any of ``x_vars`` or ``iso3`` are
    dropped. The model is fit on standardized regressors (the intercept column
    is left as is) and the coefficients are mapped back to the original scale.

    Parameters
    ----------
    df : DataFrame containing ``y_var``, ``x_vars`` and ``iso3``.
    y_var : outcome column; cast to float and presumably coded 0/1
        (the Bernoulli likelihood assumes a binary outcome) -- TODO confirm.
    x_vars : list of regressor column names; an intercept is added here.
    label : model name used in printed output and in the result dict.

    Returns
    -------
    None when fewer than 50 complete rows remain, when there are fewer than
    5 events or 5 non-events, or when estimation raises. Otherwise a tuple
    ``(res, beta, se, pvalues)``:

    * ``res`` -- dict with the same keys as ``run_panel_gls`` results, except
      that ``r_squared`` holds McFadden's pseudo-R², ``rho`` is a 0.0
      placeholder, ``{name}_coef`` is the marginal effect at the regressor
      means and ``{name}_se`` is the coefficient SE scaled by p(1-p) at the
      means (this ignores sampling uncertainty in p itself). ``{name}_p`` is
      the p-value of the underlying logit coefficient.
    * ``beta``, ``se``, ``pvalues`` -- arrays of length ``len(x_vars) + 1`` on
      the original scale; index 0 is the intercept.

    NOTE(review): unlike ``run_panel_gls`` this returns a tuple on success, so
    callers must check for None before unpacking.
    """
    from scipy.optimize import minimize
    from scipy import stats as sp_stats

    cols = [y_var] + x_vars + ['iso3']
    sub = df[cols].dropna()
    if len(sub) < 50:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[y_var].values.astype(float)
    # Design matrix: leading column of ones for the intercept.
    X = np.column_stack([np.ones(len(sub)), sub[x_vars].values.astype(float)])
    n, k = X.shape

    # Need at least 5 events and 5 non-events for a meaningful fit.
    if y.sum() < 5 or (1 - y).sum() < 5:
        print(f"  {label}: insufficient outcome variation (events={y.sum():.0f}), skipping")
        return None

    # Standardize X for numerical stability (except intercept)
    # Zero-variance columns keep a divisor of 1 (i.e. are only centered).
    x_means = X[:, 1:].mean(axis=0)
    x_stds = X[:, 1:].std(axis=0)
    x_stds[x_stds == 0] = 1
    X_std = X.copy()
    X_std[:, 1:] = (X[:, 1:] - x_means) / x_stds

    def neg_log_likelihood(beta):
        # Negative Bernoulli log-likelihood on the standardized design.
        z = X_std @ beta
        z = np.clip(z, -30, 30)  # keep exp() finite
        p = 1 / (1 + np.exp(-z))
        p = np.clip(p, 1e-12, 1 - 1e-12)  # keep log() finite
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    def gradient(beta):
        # Analytic gradient of neg_log_likelihood: -X'(y - p).
        z = X_std @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        return -X_std.T @ (y - p)

    beta0 = np.zeros(k)

    try:
        result = minimize(neg_log_likelihood, beta0, jac=gradient,
                         method='BFGS', options={'maxiter': 1000, 'gtol': 1e-6})
        if not result.success:
            print(f"  {label}: logit optimization did not converge, using result anyway")

        beta_std = result.x

        # Transform back to original scale
        # slope_j = b_j / sd_j; the intercept absorbs the centering terms.
        beta = np.zeros(k)
        beta[1:] = beta_std[1:] / x_stds
        beta[0] = beta_std[0] - np.sum(beta_std[1:] * x_means / x_stds)

        # Hessian for standard errors
        # X'WX with W = p(1-p) evaluated at the estimate on the original
        # scale; its inverse is the conventional (non-clustered) MLE
        # covariance matrix.
        z = X @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        W = p * (1 - p)
        H = X.T @ (X * W[:, None])

        try:
            V = np.linalg.inv(H)
            se = np.sqrt(np.diag(V))
        except np.linalg.LinAlgError:
            # Singular information matrix: SEs (and hence p-values) are NaN.
            se = np.full(k, np.nan)

        # Two-sided Wald z-test p-values.
        t_stats = beta / se
        pvalues = 2 * (1 - sp_stats.norm.cdf(np.abs(t_stats)))

        # Pseudo-R² (McFadden)
        # The likelihood is invariant to the standardization, so evaluating
        # at beta_std gives the fitted log-likelihood.
        ll_model = -neg_log_likelihood(beta_std)
        p_bar = y.mean()
        # Null model: intercept only, fitted probability = sample event rate.
        ll_null = n * (p_bar * np.log(p_bar + 1e-12) + (1 - p_bar) * np.log(1 - p_bar + 1e-12))
        pseudo_r2 = 1 - ll_model / ll_null if ll_null != 0 else 0

        # Marginal effects at mean
        # dP/dx_j = beta_j * p(1-p), evaluated at the column means of X.
        z_mean = X.mean(axis=0) @ beta
        p_mean = 1 / (1 + np.exp(-z_mean))
        mfx = beta[1:] * p_mean * (1 - p_mean)

    except Exception as e:
        print(f"  {label}: logit failed ({e}), skipping")
        return None

    res = {
        'model': label,
        'n_obs': n,
        'n_countries': sub['iso3'].nunique(),
        'r_squared': pseudo_r2,
        'rho': 0.0,
    }

    print(f"\n  {label} (N={n}, Pseudo-R²={pseudo_r2:.4f}) [Logit]")
    for i, name in enumerate(x_vars):
        sig = stars(pvalues[i + 1])
        print(f"    {name:30s} β={beta[i+1]:8.4f} (se={se[i+1]:.4f}) {sig}  "
              f"[MFX={mfx[i]:.5f}]")
        # Table entries report marginal effects, not raw logit coefficients.
        res[f'{name}_coef'] = mfx[i]
        res[f'{name}_se'] = se[i + 1] * p_mean * (1 - p_mean)
        res[f'{name}_p'] = pvalues[i + 1]

    return res, beta, se, pvalues


def write_table(results, filename, title, note=None):
    """Write regression results as a markdown table to ``OUT_TABLES / filename``.

    Parameters
    ----------
    results : list of dict
        Result dicts as produced by ``run_panel_gls``: each has 'model',
        'n_obs', 'r_squared', 'n_countries' and, per regressor ``v``,
        '{v}_coef', '{v}_se', '{v}_p'. A regressor absent from a model is
        left blank in that model's column.
    filename : str
        File name inside OUT_TABLES.
    title : str
        Markdown H1 title.
    note : str, optional
        Estimator footnote. Defaults to the Panel GLS note, so existing
        callers produce the same footer as before.

    Nothing is written when ``results`` is empty. The file is written as
    UTF-8 because the table contains non-ASCII characters (e.g. "R²").
    """
    if not results:
        return

    if note is None:
        note = ("Panel GLS with country and year fixed effects. "
                "Standard errors in parentheses.")

    lines = [f"# {title}\n"]

    # Union of regressor names across all models, in first-seen order.
    all_vars = []
    for r in results:
        for key in r:
            if key.endswith('_coef'):
                vname = key[:-len('_coef')]  # strip the suffix only
                if vname not in all_vars:
                    all_vars.append(vname)

    model_labels = [r['model'] for r in results]
    lines.append("| Variable | " + " | ".join(model_labels) + " |")
    lines.append("|:---|" + "|".join(["---:" for _ in results]) + "|")

    # Two rows per regressor: coefficient with stars, then SE in parentheses.
    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Blank spacer row between coefficients and fit statistics. A second
    # "|:---|---:|" delimiter row is not valid mid-table in GFM and would be
    # rendered as literal ":---" / "---:" cell text.
    lines.append("| |" + " |" * len(results))
    n_row = "| N |"
    r2_row = "| R² |"
    nc_row = "| Countries |"
    for r in results:
        n_row += f" {r['n_obs']} |"
        r2_row += f" {r['r_squared']:.4f} |"
        nc_row += f" {r['n_countries']} |"
    lines.append(n_row)
    lines.append(r2_row)
    lines.append(nc_row)

    lines.append(f"\n*{note}*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── 1. Eurozone-Only CA Regressions ────────────────────────────────────

def eurozone_ca(df):
    """Run current-account regressions on the eurozone-only subsample.

    Each member is kept only from its euro-adoption year (EUROZONE_JOIN)
    onwards. Three PanelGLS models of ``ca_gdp`` are estimated: Z factors
    plus controls on the eurozone sample, dependency ratios plus controls on
    the eurozone sample, and Z factors plus controls on the full panel from
    1999 for comparison. Results go to ``eurozone_ca.md``.

    Returns the eurozone subsample DataFrame.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("1. EUROZONE-ONLY CA REGRESSIONS")
    print(banner)

    # One slice per member (in EUROZONE_JOIN order), post-accession years only.
    ez_df = pd.concat(
        [df[(df['iso3'] == code) & (df['year'] >= first_year)]
         for code, first_year in EUROZONE_JOIN.items()],
        ignore_index=True,
    )
    print(f"  Eurozone sample: {len(ez_df)} obs, {ez_df['iso3'].nunique()} countries")
    ez_years = ez_df['year']
    print(f"  Year range: {ez_years.min()}-{ez_years.max()}")

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    demo_spec = ['Z_1', 'Z_2', 'Z_3'] + controls
    age_spec = ['old_dep', 'youth_dep'] + controls
    post_1999 = df[df['year'] >= 1999].copy()

    specs = (
        (ez_df, demo_spec, 'EZ: Z'),
        (ez_df, age_spec, 'EZ: Age'),
        (post_1999, demo_spec, 'Full post-99'),
    )
    results = []
    for sample, regressors, label in specs:
        fitted = run_panel_gls(sample, 'ca_gdp', regressors, label)
        if fitted:
            results.append(fitted)

    write_table(results, "eurozone_ca.md",
                "Current Account Regressions: Eurozone Subsample")

    return ez_df


# ── 2. Eurozone vs OECD Floaters ──────────────────────────────────────

def eurozone_vs_floaters(df, ez_df):
    """Compare CA/GDP regressions for the eurozone and OECD floaters.

    The floater sample is every panel row with ``oecd_floater == 1`` and
    ``year >= 1999``; the flag is read from the panel, not derived here.
    Z-factor and dependency-ratio specifications (both with controls) are
    estimated on each sample and written to ``eurozone_vs_floaters.md``.

    Parameters
    ----------
    df : full panel.
    ez_df : eurozone subsample (as returned by ``eurozone_ca``).
    """
    banner = "=" * 60
    print("\n" + banner)
    print("2. EUROZONE vs. OECD FLOATERS")
    print(banner)

    in_floater_window = (df['oecd_floater'] == 1) & (df['year'] >= 1999)
    floater_df = df[in_floater_window].copy()
    floater_codes = floater_df['iso3']
    print(f"  OECD floaters: {floater_codes.nunique()} countries, "
          f"{len(floater_df)} obs")
    print(f"  Floater countries: {sorted(floater_codes.unique())}")

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    demo_spec = ['Z_1', 'Z_2', 'Z_3'] + controls
    age_spec = ['old_dep', 'youth_dep'] + controls

    specs = (
        (ez_df, demo_spec, 'Eurozone'),
        (floater_df, demo_spec, 'OECD Floaters'),
        (ez_df, age_spec, 'EZ: Age'),
        (floater_df, age_spec, 'Float: Age'),
    )
    results = []
    for sample, regressors, label in specs:
        fitted = run_panel_gls(sample, 'ca_gdp', regressors, label)
        if fitted:
            results.append(fitted)

    write_table(results, "eurozone_vs_floaters.md",
                "Eurozone vs. OECD Floaters: CA/GDP Regressions (Post-1999)")


# ── 3. Within-EMU Dispersion ──────────────────────────────────────────

def _yearly_dispersion(ez_df, min_countries=5):
    """Per-year cross-country spread of Z_1 and CA/GDP within an EMU sample.

    Only country-years where BOTH ``Z_1`` and ``ca_gdp`` are observed are
    used. That way both dispersion series are measured on the same countries
    and contain no NaN. A year with missing CA data would otherwise produce a
    NaN standard deviation and turn every ``linregress`` output into NaN.
    Years with fewer than ``min_countries`` complete observations are skipped.

    Returns a DataFrame with columns year, demo_dispersion, ca_dispersion,
    z1_mean, ca_mean, n_countries (one row per retained year, sorted by year).
    """
    columns = ['year', 'demo_dispersion', 'ca_dispersion',
               'z1_mean', 'ca_mean', 'n_countries']
    complete = ez_df.dropna(subset=['Z_1', 'ca_gdp'])
    rows = []
    for year, yr_df in complete.groupby('year', sort=True):
        if len(yr_df) < min_countries:
            continue
        rows.append({
            'year': year,
            'demo_dispersion': yr_df['Z_1'].std(),
            'ca_dispersion': yr_df['ca_gdp'].std(),
            'z1_mean': yr_df['Z_1'].mean(),
            'ca_mean': yr_df['ca_gdp'].mean(),
            'n_countries': len(yr_df),
        })
    return pd.DataFrame(rows, columns=columns)


def emu_dispersion(df):
    """Within-EMU demographic dispersion → CA dispersion.

    Part A: for each year, the cross-country standard deviation of Z_1 and
    of CA/GDP among members in EMU that year. This is followed by a simple
    OLS, across years, of CA dispersion on demographic dispersion.

    Part B: PanelGLS of each member's CA/GDP deviation from the EMU-year
    mean on its Z_1..Z_3 deviations, without and with controls.

    Output is printed and written (UTF-8) to ``emu_dispersion.md``. The
    function returns early, without writing, when fewer than 5 usable years
    exist or demographic dispersion is constant.
    """
    print("\n" + "=" * 60)
    print("3. WITHIN-EMU DEMOGRAPHIC DISPERSION → CA DISPERSION")
    print("=" * 60)

    # Filter to eurozone members after their join year
    ez_rows = []
    for iso3, join_yr in EUROZONE_JOIN.items():
        mask = (df['iso3'] == iso3) & (df['year'] >= join_yr)
        ez_rows.append(df[mask])
    ez_df = pd.concat(ez_rows, ignore_index=True)

    # --- 3a. Time-series: cross-sectional dispersion ---
    print("\n  --- 3a. Cross-sectional dispersion over time ---")

    disp_df = _yearly_dispersion(ez_df)
    if len(disp_df) < 5:
        print("  Insufficient years for dispersion analysis")
        return
    # linregress raises ValueError when all x values are identical.
    if disp_df['demo_dispersion'].nunique() < 2:
        print("  Demographic dispersion does not vary across years, skipping")
        return

    # Simple OLS: ca_dispersion = a + b * demo_dispersion
    from scipy import stats as sp_stats
    slope, intercept, r_value, p_value, _ = sp_stats.linregress(
        disp_df['demo_dispersion'], disp_df['ca_dispersion'])
    print(f"  OLS: ca_disp = {intercept:.3f} + {slope:.3f} * demo_disp")
    print(f"  R² = {r_value**2:.4f}, p = {p_value:.4f}, N = {len(disp_df)} years")

    # Correlation
    corr = disp_df[['demo_dispersion', 'ca_dispersion']].corr().iloc[0, 1]
    print(f"  Correlation(demo_disp, ca_disp) = {corr:.4f}")

    # --- 3b. Within-EMU: country Z deviations → country CA deviations ---
    print("\n  --- 3b. Country deviations from EMU mean ---")

    # Compute EMU-year means and subtract them from each member's values.
    emu_means = ez_df.groupby('year')[['Z_1', 'Z_2', 'Z_3', 'ca_gdp']].mean()
    emu_means.columns = [f'{c}_emu_mean' for c in emu_means.columns]
    ez_dev = ez_df.merge(emu_means, on='year', how='left')
    ez_dev['Z_1_dev'] = ez_dev['Z_1'] - ez_dev['Z_1_emu_mean']
    ez_dev['Z_2_dev'] = ez_dev['Z_2'] - ez_dev['Z_2_emu_mean']
    ez_dev['Z_3_dev'] = ez_dev['Z_3'] - ez_dev['Z_3_emu_mean']
    ez_dev['ca_dev'] = ez_dev['ca_gdp'] - ez_dev['ca_gdp_emu_mean']

    results = []

    # PanelGLS: ca_dev = f(Z_dev)
    r = run_panel_gls(ez_dev, 'ca_dev',
                      ['Z_1_dev', 'Z_2_dev', 'Z_3_dev'],
                      'CA dev ~ Z dev')
    if r: results.append(r)

    # With controls (only those present in the panel)
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    avail_controls = [c for c in controls if c in ez_dev.columns]
    r = run_panel_gls(ez_dev, 'ca_dev',
                      ['Z_1_dev', 'Z_2_dev', 'Z_3_dev'] + avail_controls,
                      'CA dev ~ Z dev + ctrl')
    if r: results.append(r)

    # --- Write dispersion table ---
    lines = ["# Within-EMU Demographic Dispersion and CA Dispersion\n"]

    # Part A: time-series dispersion
    lines.append("## A. Cross-Sectional Dispersion Over Time\n")
    lines.append(f"OLS: CA_dispersion = {intercept:.3f} + {slope:.3f}{stars(p_value)} "
                 f"* Demo_dispersion")
    lines.append(f"  - R² = {r_value**2:.4f}, p = {p_value:.4f}, N = {len(disp_df)} years")
    lines.append(f"  - Correlation = {corr:.4f}\n")

    lines.append("| Year | Demo Dispersion (σ Z₁) | CA Dispersion (σ CA/GDP) | N countries |")
    lines.append("|---:|---:|---:|---:|")
    for _, row in disp_df.iterrows():
        lines.append(f"| {int(row['year'])} | {row['demo_dispersion']:.4f} | "
                     f"{row['ca_dispersion']:.2f} | {int(row['n_countries'])} |")

    # Part B: deviation regressions
    if results:
        lines.append("\n## B. Country Deviations from EMU Mean\n")

        # Union of regressor names across models, in first-seen order.
        all_vars = []
        for r in results:
            for k in r:
                if k.endswith('_coef'):
                    vname = k[:-len('_coef')]
                    if vname not in all_vars:
                        all_vars.append(vname)

        model_labels = [r['model'] for r in results]
        header = "| Variable | " + " | ".join(model_labels) + " |"
        sep = "|:---|" + "|".join(["---:" for _ in results]) + "|"
        lines.append(header)
        lines.append(sep)

        for var in all_vars:
            coef_row = f"| {var} |"
            se_row = "| |"
            for r in results:
                if f'{var}_coef' in r:
                    c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                    coef_row += f" {c} |"
                    se_row += f" {s} |"
                else:
                    coef_row += " |"
                    se_row += " |"
            lines.append(coef_row)
            lines.append(se_row)

        # Blank spacer row; a second delimiter row is not valid mid-table in
        # GFM and would render as literal ":---" text.
        lines.append("| |" + " |" * len(results))
        n_row = "| N |"
        r2_row = "| R² |"
        nc_row = "| Countries |"
        for r in results:
            n_row += f" {r['n_obs']} |"
            r2_row += f" {r['r_squared']:.4f} |"
            nc_row += f" {r['n_countries']} |"
        lines.append(n_row)
        lines.append(r2_row)
        lines.append(nc_row)

    lines.append("\n*Panel GLS with country and year fixed effects. "
                 "Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    # UTF-8 explicitly: the table contains σ, Z₁ and ², which locale
    # encodings such as cp1252 cannot represent.
    path = OUT_TABLES / "emu_dispersion.md"
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── 4. Counterfactual: What Regime Would EMU Members Choose? ──────────

def _fit_standardized_logit(X, y):
    """Fit a pooled logit by BFGS on column-standardized regressors.

    ``X`` must carry a leading intercept column. The remaining columns are
    standardized for numerical stability; zero-variance columns are only
    centered. The fitted coefficients are mapped back to the scale of ``X``.

    Returns
    -------
    beta : ndarray of coefficients on the original scale of ``X``.
    converged : bool, the optimizer's success flag.
    """
    from scipy.optimize import minimize

    k = X.shape[1]
    x_means = X[:, 1:].mean(axis=0)
    x_stds = X[:, 1:].std(axis=0)
    x_stds[x_stds == 0] = 1
    X_std = X.copy()
    X_std[:, 1:] = (X[:, 1:] - x_means) / x_stds

    def neg_log_likelihood(beta):
        z = np.clip(X_std @ beta, -30, 30)                   # keep exp() finite
        p = np.clip(1 / (1 + np.exp(-z)), 1e-12, 1 - 1e-12)  # keep log() finite
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    def gradient(beta):
        z = np.clip(X_std @ beta, -30, 30)
        p = 1 / (1 + np.exp(-z))
        return -X_std.T @ (y - p)

    opt_result = minimize(neg_log_likelihood, np.zeros(k), jac=gradient,
                          method='BFGS', options={'maxiter': 1000, 'gtol': 1e-6})
    beta_std = opt_result.x

    # Undo the standardization: slope_j = b_j / sd_j; intercept absorbs centering.
    beta = np.zeros(k)
    beta[1:] = beta_std[1:] / x_stds
    beta[0] = beta_std[0] - np.sum(beta_std[1:] * x_means / x_stds)
    return beta, bool(opt_result.success)


def emu_counterfactual(df):
    """Predict regime choice for EMU members using a non-EMU OECD logit.

    A pooled logit of ``is_peg`` on Z_1..Z_3 and controls is fit on OECD
    countries outside the eurozone for years >= 1999. Its coefficients are
    then applied to each eurozone member's post-accession observations to get
    P(peg), classified as a peg at P >= 0.5. Country-level and year-level
    summaries are printed and written (UTF-8) to ``emu_counterfactual.md``.

    ``is_peg`` is used from the panel when present. Otherwise it is derived
    as ``regime_3cat == 1``, with missing regime codes kept missing (and
    hence dropped) rather than silently coded as non-peg.
    """
    print("\n" + "=" * 60)
    print("4. COUNTERFACTUAL: WHAT REGIME WOULD EMU MEMBERS CHOOSE?")
    print("=" * 60)

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']
    x_vars = ['Z_1', 'Z_2', 'Z_3'] + controls

    # Training sample: OECD non-eurozone, post-1999
    train_df = df[(df['iso3'].isin(OECD)) &
                  (~df['iso3'].isin(EUROZONE_ISO3)) &
                  (df['year'] >= 1999)].copy()

    # Ensure is_peg exists
    if 'is_peg' not in train_df.columns:
        regime = train_df['regime_3cat']
        # (regime == 1) is False for NaN; mask those back to NaN so that
        # unclassified years are dropped below instead of counted as floats.
        train_df['is_peg'] = (regime == 1).astype(float).where(regime.notna())

    cols = ['is_peg'] + x_vars + ['iso3']
    train_sub = train_df[cols].dropna()
    print(f"  Training sample (OECD non-EZ): {len(train_sub)} obs, "
          f"{train_sub['iso3'].nunique()} countries")
    print(f"  Pegs in training: {train_sub['is_peg'].sum():.0f} / {len(train_sub)}")

    y_train = train_sub['is_peg'].values.astype(float)
    X_train = np.column_stack([np.ones(len(train_sub)),
                               train_sub[x_vars].values.astype(float)])
    n = len(train_sub)

    if y_train.sum() < 5 or (1 - y_train).sum() < 5:
        print("  Insufficient outcome variation in training sample, skipping")
        return

    try:
        beta, converged = _fit_standardized_logit(X_train, y_train)
    except Exception as e:
        print(f"  Logit estimation failed: {e}")
        return

    print(f"  Logit converged: {converged}")
    print("  Coefficients (original scale):")
    for i, name in enumerate(['const'] + x_vars):
        print(f"    {name:30s} {beta[i]:8.4f}")

    # --- Predict for eurozone members ---
    # Keep each member only from its euro-adoption year onwards.
    ez_rows = []
    for iso3, join_yr in EUROZONE_JOIN.items():
        mask = (df['iso3'] == iso3) & (df['year'] >= join_yr)
        ez_rows.append(df[mask])
    ez_df = pd.concat(ez_rows, ignore_index=True)
    ez_pred = ez_df[['iso3', 'year'] + x_vars].dropna()
    print(f"\n  Prediction sample (EZ members): {len(ez_pred)} obs, "
          f"{ez_pred['iso3'].nunique()} countries")

    X_pred = np.column_stack([np.ones(len(ez_pred)),
                              ez_pred[x_vars].values.astype(float)])
    z_pred = np.clip(X_pred @ beta, -30, 30)
    p_peg = 1 / (1 + np.exp(-z_pred))

    ez_pred = ez_pred.copy()
    ez_pred['p_peg'] = p_peg
    ez_pred['predicted_peg'] = (p_peg >= 0.5).astype(int)

    # Summary by country
    print("\n  --- Counterfactual Regime Predictions for EMU Members ---")
    country_summary = ez_pred.groupby('iso3').agg(
        mean_p_peg=('p_peg', 'mean'),
        pct_predicted_peg=('predicted_peg', 'mean'),
        n_obs=('p_peg', 'count'),
    ).sort_values('mean_p_peg', ascending=False)

    for iso3, row in country_summary.iterrows():
        regime = "PEG" if row['pct_predicted_peg'] >= 0.5 else "FLOAT"
        print(f"    {iso3}: P(peg)={row['mean_p_peg']:.3f}, "
              f"predicted={regime} ({row['pct_predicted_peg']*100:.0f}% peg), "
              f"N={int(row['n_obs'])}")

    # Aggregate shares: a country "would peg" if most of its years are pegs.
    n_would_peg = (country_summary['pct_predicted_peg'] >= 0.5).sum()
    n_would_float = (country_summary['pct_predicted_peg'] < 0.5).sum()
    n_total = len(country_summary)
    print(f"\n  Summary: {n_would_peg}/{n_total} EMU members would choose PEG, "
          f"{n_would_float}/{n_total} would FLOAT")

    # Over-time trend
    yearly_pred = ez_pred.groupby('year').agg(
        mean_p_peg=('p_peg', 'mean'),
        pct_predicted_peg=('predicted_peg', 'mean'),
        n=('p_peg', 'count'),
    )

    # --- Write counterfactual table ---
    lines = ["# EMU Counterfactual: Predicted Regime Choice\n"]
    lines.append("Logit trained on OECD non-eurozone (post-1999): is_peg = f(Z, controls)\n")
    lines.append(f"Training sample: N={n}, Peg share={y_train.mean():.3f}\n")

    lines.append("## A. Country-Level Predictions\n")
    lines.append("| Country | Mean P(peg) | % Obs Predicted Peg | Predicted Regime | N |")
    lines.append("|:---|---:|---:|:---|---:|")
    for iso3, row in country_summary.iterrows():
        regime = "Peg" if row['pct_predicted_peg'] >= 0.5 else "Float"
        lines.append(f"| {iso3} | {row['mean_p_peg']:.3f} | "
                     f"{row['pct_predicted_peg']*100:.0f}% | {regime} | "
                     f"{int(row['n_obs'])} |")

    lines.append(f"\n**Summary**: {n_would_peg}/{n_total} EMU members predicted to PEG, "
                 f"{n_would_float}/{n_total} predicted to FLOAT\n")

    lines.append("## B. Over-Time Trend\n")
    lines.append("| Year | Mean P(peg) | % Predicted Peg | N |")
    lines.append("|---:|---:|---:|---:|")
    for year, row in yearly_pred.iterrows():
        lines.append(f"| {int(year)} | {row['mean_p_peg']:.3f} | "
                     f"{row['pct_predicted_peg']*100:.0f}% | {int(row['n'])} |")

    lines.append("\n*Logit estimated on OECD non-eurozone sample. "
                 "Predictions applied to eurozone members using their actual demographics.*")

    path = OUT_TABLES / "emu_counterfactual.md"
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── Main ───────────────────────────────────────────────────────────────

def main():
    """Load the processed trilemma panel and run the four Phase 5 analyses."""
    rule = "=" * 70
    print(rule)
    print("PHASE 5: EUROZONE STRESS TEST — DEMOGRAPHICS UNDER A FIXED REGIME")
    print(rule)

    panel = pd.read_csv(DATA / "trilemma_panel.csv")
    print(f"Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries")
    print(f"Year range: {panel['year'].min()}-{panel['year'].max()}")

    # Optional membership flags: report counts only when the panel has them.
    flag_captions = (
        ('eurozone', 'Eurozone countries in panel'),
        ('oecd_floater', 'OECD floater countries in panel'),
    )
    for flag, caption in flag_captions:
        if flag in panel.columns:
            members = panel.loc[panel[flag] == 1, 'iso3'].nunique()
            print(f"{caption}: {members}")

    ez_df = eurozone_ca(panel)          # 1. Eurozone-only CA regressions
    eurozone_vs_floaters(panel, ez_df)  # 2. Eurozone vs OECD floaters
    emu_dispersion(panel)               # 3. Within-EMU dispersion
    emu_counterfactual(panel)           # 4. Counterfactual regime choice

    print("\n" + rule)
    print("PHASE 5 COMPLETE")
    print(rule)


# Run the Phase 5 pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
